{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mixed-precision matrix multiplication using cudnn FE\n",
    "This notebook shows how to run a mixed-precision matmul with the cudnn frontend (FE): an int8 tensor A is cast to bfloat16 inside the graph and multiplied by a bfloat16 tensor B."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/03_mixed_precision_matmul.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites for running on Colab\n",
    "This notebook requires an NVIDIA H100 GPU or newer. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm that a GPU is selected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('nvidia-smi')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you are running on Colab, install the cudnn Python interface first by uncommenting the commands below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('pip install nvidia-cudnn-cu12')\n",
    "# get_ipython().system('pip install nvidia-cudnn-frontend')\n",
    "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### General Setup\n",
    "We call cudnn through torch in this example. In general, any tensor that supports DLPack should work.\n",
    "The cudnn handle is a per-device handle used to initialize the cudnn context.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cudnn\n",
    "import torch\n",
    "\n",
    "handle = cudnn.create_handle()"
   ]
  },
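  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The H100 requirement from the prerequisites corresponds to CUDA compute capability 9.0. The optional check below is a minimal sketch that assumes torch can see the GPU. `required_capability` is an illustrative name, not part of the cudnn API, and the comparison is only a rough proxy for \"H100 or newer\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cudnn\n",
    "import torch\n",
    "\n",
    "# Assumption: H100 (Hopper) reports compute capability (9, 0).\n",
    "required_capability = (9, 0)\n",
    "\n",
    "# (major, minor) compute capability of the current CUDA device\n",
    "capability = torch.cuda.get_device_capability()\n",
    "\n",
    "print(\"cudnn backend version:\", cudnn.backend_version())\n",
    "print(\"GPU compute capability:\", capability)\n",
    "\n",
    "# tuples compare element by element, so (9, 0) <= (10, 0) holds\n",
    "if capability < required_capability:\n",
    "    print(\"Warning: this sample targets H100 or newer; the graph may not be supported on this GPU.\")"
   ]
  },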
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create input tensors and calculate reference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch, m, n, k = 16, 128, 128, 512\n",
    "\n",
    "# the two inputs can have different data types\n",
    "input_type_a = torch.int8\n",
    "input_type_b = torch.bfloat16\n",
    "output_type = torch.bfloat16\n",
    "\n",
    "# data type of both operands as they enter the matmul (after any cast)\n",
    "mma_data_type = torch.bfloat16\n",
    "\n",
    "# input tensors\n",
    "if input_type_a != torch.int8:\n",
    "    a = 2 * torch.randn(batch, m, k, dtype=input_type_a, device=\"cuda\") - 0.5\n",
    "else:\n",
    "    a = torch.randint(4, (batch, m, k), dtype=input_type_a, device=\"cuda\") - 1\n",
    "\n",
    "if input_type_b != torch.int8:\n",
    "    b_row_major = 3 * torch.randn(batch, k, n, dtype=input_type_b, device=\"cuda\") - 1.25\n",
    "else:\n",
    "    b_row_major = (\n",
    "        torch.randint(3, (batch, k, n), dtype=input_type_b, device=\"cuda\").contiguous()\n",
    "        - 2\n",
    "    )\n",
    "# view the same buffer as column-major: stride 1 along k, stride k along n\n",
    "b = torch.as_strided(b_row_major, (batch, k, n), (n * k, 1, k))\n",
    "\n",
    "# reference output\n",
    "c_ref = torch.matmul(a.to(mma_data_type), b.to(mma_data_type)).to(output_type)\n",
    "\n",
    "# placeholder for the cudnn output; graph.execute overwrites it\n",
    "c = torch.randn_like(c_ref, device=\"cuda\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create cudnn graph and tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = cudnn.pygraph()\n",
    "\n",
    "a_cudnn_tensor = graph.tensor_like(a)\n",
    "b_cudnn_tensor = graph.tensor_like(b)\n",
    "\n",
    "# the identity op casts a along the path: input data type -> compute_data_type -> output data type\n",
    "# compute_data_type can be int32 as well\n",
    "a_cudnn_tensor_casted = graph.identity(\n",
    "    input=a_cudnn_tensor, compute_data_type=cudnn.data_type.FLOAT\n",
    ")\n",
    "a_cudnn_tensor_casted.set_data_type(mma_data_type)\n",
    "\n",
    "# tensor b is not cast here because its input type already matches mma_data_type (bf16)\n",
    "# if input_type_b differed from mma_data_type, b could be cast the same way as a\n",
    "\n",
    "# compute_data_type should be set to int32 if the mma_data_type is int8\n",
    "c_cudnn_tensor = graph.matmul(\n",
    "    name=\"matmul\",\n",
    "    A=a_cudnn_tensor_casted,\n",
    "    B=b_cudnn_tensor,\n",
    "    compute_data_type=cudnn.data_type.FLOAT,\n",
    ")\n",
    "c_cudnn_tensor.set_name(\"c\").set_output(True).set_data_type(output_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Build the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.validate()\n",
    "graph.build_operation_graph()\n",
    "graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n",
    "graph.check_support()\n",
    "graph.build_plans()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Execute the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "variant_pack = {\n",
    "    a_cudnn_tensor: a,\n",
    "    b_cudnn_tensor: b,\n",
    "    c_cudnn_tensor: c,\n",
    "}\n",
    "\n",
    "workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n",
    "graph.execute(variant_pack, workspace)\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.testing.assert_close(c, c_ref, rtol=5e-3, atol=5e-3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "build_thunder",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
